load("//bazel:api.bzl", "mojo_test")

# Targets declared in this file default to visibility "//max:consumers"
# unless a target sets its own `visibility`.
package(default_visibility = ["//max:consumers"])

# Dependencies shared by every mojo_test target defined below.
_DEPS = [
    "//max:internal_utils",
    "//max:linalg",
    "//max:quantization",
    "@mojo//:stdlib",
]

# Tests restricted to specific GPUs via _EXTRA_CONSTRAINTS. They get their own
# mojo_test targets (with a gpu-memory resource request) and are excluded from
# the catch-all glob further down.
_FP8_FP4_TESTS = [
    "test_scaled_matmul.mojo",
    "test_scaled_fp8_quantization.mojo",
    "test_scaled_nvfp4_quantization.mojo",
    "test_nvfp4_block_scales_interleave.mojo",
]

# Compatibility gate for tests that may run on either H100 or B200. On any
# other configuration it resolves to @platforms//:incompatible, which marks
# the target incompatible rather than letting it build and fail.
_H100_OR_B200 = select({
    "//:h100_gpu": ["//:h100_gpu"],
    "//:b200_gpu": ["//:b200_gpu"],
    "//conditions:default": ["@platforms//:incompatible"],
})

# Compatibility gate for tests that may only run on B200.
_B200_ONLY = select({
    "//:b200_gpu": ["//:b200_gpu"],
    "//conditions:default": ["@platforms//:incompatible"],
})

# Per-source extra `target_compatible_with` entries, appended to the baseline
# ["//:has_gpu"] constraint when the FP8/FP4 test targets are generated.
_EXTRA_CONSTRAINTS = {
    "test_scaled_matmul.mojo": _H100_OR_B200,
    "test_scaled_fp8_quantization.mojo": _H100_OR_B200,
    "test_scaled_nvfp4_quantization.mojo": _B200_ONLY,
    "test_nvfp4_block_scales_interleave.mojo": _B200_ONLY,
}

# One GPU test per FP8/FP4 source. Each requests the "test.resources:gpu-memory"
# execution resource and narrows compatibility with that source's entry in
# _EXTRA_CONSTRAINTS (no extra constraint if the source has no entry).
[
    mojo_test(
        name = test_file + ".test",
        srcs = [test_file],
        exec_properties = {"test.resources:gpu-memory": "3"},
        tags = ["gpu"],
        target_compatible_with = ["//:has_gpu"] + _EXTRA_CONSTRAINTS.get(test_file, []),
        deps = _DEPS,
    )
    for test_file in _FP8_FP4_TESTS
]

# Catch-all: every remaining .mojo file (the `**` pattern also matches files in
# subdirectories) becomes a GPU test. Sources that have hand-written targets
# elsewhere in this file are excluded so no target name is defined twice.
[
    mojo_test(
        name = mojo_src + ".test",
        srcs = [mojo_src],
        tags = ["gpu"],
        target_compatible_with = ["//:has_gpu"],
        deps = _DEPS,
    )
    for mojo_src in glob(
        ["**/*.mojo"],
        exclude = _FP8_FP4_TESTS + ["test_multistage_gemm_q.mojo"],
    )
]

# Hand-written target: unlike the glob-generated tests, this one is limited to
# NVIDIA GPUs and requests the "test.resources:gpu-memory" resource.
mojo_test(
    name = "test_multistage_gemm_q.mojo.test",
    srcs = ["test_multistage_gemm_q.mojo"],
    exec_properties = {"test.resources:gpu-memory": "3"},
    tags = ["gpu"],
    # FIXME(KERN-1377): NVIDIA-only for now; once the ticket is resolved, move
    # this test into the glob above.
    target_compatible_with = ["//:nvidia_gpu"],
    deps = _DEPS,
)
